{ "cells": [ { "cell_type": "markdown", "id": "2a31ca29", "metadata": {}, "source": [ "LeNet5 architecture\n", "\n", "(Using tanh, avg pool, no Gaussian connections, and complete C3 connections)" ] }, { "cell_type": "code", "execution_count": 8, "id": "3840aa28", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0th epoch starting.\n", "1th epoch starting.\n", "2th epoch starting.\n", "3th epoch starting.\n", "4th epoch starting.\n", "5th epoch starting.\n", "6th epoch starting.\n", "7th epoch starting.\n", "8th epoch starting.\n", "9th epoch starting.\n", "Time ellapsed in training is: 55.27556610107422\n", "[Test set] Average loss: 0.0002, Accuracy: 9307/10000 (93.07%)\n", "\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.optim import Optimizer\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision.transforms import transforms\n", "\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "# device = \"cpu\"\n", "\n", "'''\n", "Step 1:\n", "'''\n", "\n", "# MNIST dataset\n", "train_dataset = datasets.MNIST(root='./mnist_data/',\n", " train=True, \n", " transform=transforms.ToTensor(),\n", " download=True)\n", "\n", "test_dataset = datasets.MNIST(root='./mnist_data/',\n", " train=False, \n", " transform=transforms.ToTensor())\n", "\n", "\n", "'''\n", "Step 2: LeNet5\n", "'''\n", "class LeNet(nn.Module) :\n", " \n", " def __init__(self) :\n", " super(LeNet, self).__init__()\n", " \n", " #padding=2 makes 28x28 image into 32x32\n", " self.conv_layer1 = nn.Sequential(\n", " nn.Conv2d(1, 6, kernel_size=5, padding=2),\n", " nn.Tanh()\n", " )\n", " self.pool_layer1 = nn.Sequential(\n", " nn.AvgPool2d(kernel_size=2, stride=2),\n", " nn.Tanh()\n", " )\n", " self.conv_layer2 = nn.Sequential(\n", " nn.Conv2d(6, 16, kernel_size=5),\n", " nn.Tanh()\n", " )\n", " self.pool_layer2 = nn.Sequential(\n", " nn.AvgPool2d(kernel_size=2, stride=2),\n", " nn.Tanh()\n", " )\n", "# self.conv_layer3 = nn.Sequential(\n", "# nn.Conv2d(16, 120, kernel_size=5),\n", "# nn.Tanh()\n", "# )\n", " self.C5_layer = nn.Sequential(\n", " nn.Linear(5*5*16, 120),\n", " nn.Tanh()\n", " )\n", " self.fc_layer1 = nn.Sequential(\n", " nn.Linear(120, 84),\n", " nn.Tanh()\n", " )\n", " self.fc_layer2 = nn.Linear(84, 10)\n", " \n", " \n", " def forward(self, x) :\n", " output = self.conv_layer1(x)\n", " output = self.pool_layer1(output)\n", " output = self.conv_layer2(output)\n", " output = self.pool_layer2(output)\n", "# output = self.conv_layer3(output)\n", "# output = output.view(-1, 120)\n", " output = output.view(-1,5*5*16)\n", " output = self.C5_layer(output)\n", " output = self.fc_layer1(output)\n", " output = self.fc_layer2(output)\n", " return output\n", "\n", " \n", "'''\n", "Step 3\n", "'''\n", "model = LeNet().to(device)\n", "loss_function = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=1e-1) # lr hand-tuned\n", "\n", "'''\n", "Step 4\n", "'''\n", "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=1024, shuffle=True)\n", "\n", "import time\n", "start = time.time()\n", "for epoch in range(10) :\n", " print(f\"{epoch}th epoch starting.\")\n", " for images, labels in train_loader :\n", " images, labels = images.to(device), labels.to(device)\n", " \n", " optimizer.zero_grad()\n", " train_loss = loss_function(model(images), labels)\n", " train_loss.backward()\n", "\n", " optimizer.step()\n", "end = time.time()\n", 
"print(\"Time ellapsed in training is: {}\".format(end - start))\n", "\n", "\n", "'''\n", "Step 5\n", "'''\n", "test_loss, correct, total = 0, 0, 0\n", "\n", "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1024, shuffle=False)\n", "\n", "for images, labels in test_loader :\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " output = model(images)\n", " test_loss += loss_function(output, labels).item()\n", "\n", " pred = output.max(1, keepdim=True)[1]\n", " correct += pred.eq(labels.view_as(pred)).sum().item()\n", " \n", " total += labels.size(0)\n", " \n", "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n", " test_loss /total, correct, total,\n", " 100. * correct / total))" ] }, { "cell_type": "markdown", "id": "0f63baaf", "metadata": {}, "source": [ "Modern variant of LeNet5\n", "\n", "(Using ReLU, max pool, no Gaussian connections, and complete C3 connections)" ] }, { "cell_type": "code", "execution_count": 9, "id": "ee82c663", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0th epoch starting.\n", "1th epoch starting.\n", "2th epoch starting.\n", "3th epoch starting.\n", "4th epoch starting.\n", "5th epoch starting.\n", "6th epoch starting.\n", "7th epoch starting.\n", "8th epoch starting.\n", "9th epoch starting.\n", "Time ellapsed in training is: 54.80843114852905\n", "[Test set] Average loss: 0.0001, Accuracy: 9695/10000 (96.95%)\n", "\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.optim import Optimizer\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision.transforms import transforms\n", "\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "'''\n", "Step 1:\n", "'''\n", "\n", "# MNIST dataset\n", "train_dataset = datasets.MNIST(root='./mnist_data/',\n", " train=True, \n", " transform=transforms.ToTensor(),\n", " download=True)\n", "\n", "test_dataset = datasets.MNIST(root='./mnist_data/',\n", " train=False, \n", " transform=transforms.ToTensor())\n", "\n", "'''\n", "Step 2\n", "'''\n", "class LeNet(nn.Module) :\n", " \n", " def __init__(self) :\n", " super(LeNet, self).__init__()\n", " \n", " #padding=2 makes 28x28 image into 32x32\n", " self.conv_layer1 = nn.Sequential(\n", " nn.Conv2d(1, 6, kernel_size=5, padding=2),\n", " nn.ReLU()\n", " )\n", " self.pool_layer1 = nn.MaxPool2d(kernel_size=2, stride=2)\n", " self.conv_layer2 = nn.Sequential(\n", " nn.Conv2d(6, 16, kernel_size=5),\n", " nn.ReLU()\n", " )\n", " self.pool_layer2 = nn.MaxPool2d(kernel_size=2, stride=2) \n", " self.C5_layer = nn.Sequential(\n", " nn.Linear(5*5*16, 120),\n", " nn.ReLU()\n", " )\n", " self.fc_layer1 = nn.Sequential(\n", " nn.Linear(120, 84),\n", " nn.ReLU()\n", " )\n", " self.fc_layer2 = nn.Linear(84, 10)\n", " \n", " \n", " def forward(self, x) :\n", " output = self.conv_layer1(x)\n", " output = self.pool_layer1(output)\n", " output = self.conv_layer2(output)\n", " output = self.pool_layer2(output)\n", " output = output.view(-1,5*5*16)\n", " output = self.C5_layer(output)\n", " output = self.fc_layer1(output)\n", " output = self.fc_layer2(output)\n", " return output\n", "\n", " \n", "'''\n", "Step 3\n", "'''\n", "model = LeNet().to(device)\n", "loss_function = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=1e-1) # lr hand-tuned\n", "\n", "'''\n", "Step 4\n", "'''\n", "train_loader = 
"train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=1024, shuffle=True)\n",
"\n",
"import time\n",
"start = time.time()\n",
"for epoch in range(10) :\n",
"    print(f\"Epoch {epoch} starting.\")\n",
"    for images, labels in train_loader :\n",
"        images, labels = images.to(device), labels.to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"        train_loss = loss_function(model(images), labels)\n",
"        train_loss.backward()\n",
"\n",
"        optimizer.step()\n",
"end = time.time()\n",
"print(\"Time elapsed in training is: {}\".format(end - start))\n",
"\n",
"\n",
"'''\n",
"Step 5\n",
"'''\n",
"test_loss, correct, total = 0, 0, 0\n",
"\n",
"test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1024, shuffle=False)\n",
"\n",
"for images, labels in test_loader :\n",
"    images, labels = images.to(device), labels.to(device)\n",
"\n",
"    output = model(images)\n",
"    test_loss += loss_function(output, labels).item()\n",
"\n",
"    pred = output.max(1, keepdim=True)[1]\n",
"    correct += pred.eq(labels.view_as(pred)).sum().item()\n",
"\n",
"    total += labels.size(0)\n",
"\n",
"print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n",
"    test_loss / total, correct, total,\n",
"    100. * correct / total))" ] },
{ "cell_type": "markdown", "id": "6d9d8d82", "metadata": {}, "source": [
"\n",
"Classical LeNet5 on CIFAR10\n",
"\n",
"Result: \n",
"[Test set] Average loss: 0.0128, Accuracy: 5983/10000 (59.83%)\n",
"\n",
"(Baseline without augmentation; the next cell only previews the augmentation transforms used in the runs below.)" ] },
{ "cell_type": "code", "execution_count": 10, "id": "ea406a0b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar_10data/cifar-10-python.tar.gz\n" ] },
{ "data": { "text/plain": [ "<Figure size 1080x720 with 25 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets\n",
"from torchvision.transforms import transforms\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"# Image preprocessing modules\n",
"# transform = transforms.Compose([\n",
"#     transforms.Pad(4),\n",
"#     transforms.RandomHorizontalFlip(),\n",
"#     transforms.RandomCrop(32),\n",
"#     transforms.ToTensor()])\n",
"\n",
"# transform = transforms.Compose([\n",
"#     transforms.ColorJitter(brightness=0.8),\n",
"#     transforms.ToTensor()])\n",
"\n",
"transform = transforms.Compose([\n",
"    transforms.RandomAffine(30),\n",
"    transforms.ToTensor()])\n",
"\n",
"# transform = transforms.Compose([\n",
"#     transforms.RandomPerspective(),\n",
"#     transforms.ToTensor()])\n",
"\n",
"# transform = transforms.Compose([\n",
"#     transforms.RandomResizedCrop(28),\n",
"#     transforms.ToTensor()])\n",
"\n",
"# transform = transforms.Compose([\n",
"#     transforms.GaussianBlur(3),\n",
"#     transforms.ToTensor()])\n",
"\n",
"# transform = transforms.Compose([\n",
"#     transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random'),\n",
"#     transforms.ToTensor()])\n",
"# (note: RandomErasing operates on tensors, so it would have to come after ToTensor())\n",
"\n",
"\n",
"# Number of images to display\n",
"B = 25\n",
"\n",
"train_dataset = datasets.CIFAR10(root='./cifar_10data/',\n",
"                                 transform=transform,\n",
"                                 download=True)\n",
"\n",
"\n",
"train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=B, shuffle=False)\n",
"\n",
"\n",
"fig = plt.figure(figsize=(15, 10))\n",
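"# ToTensor() yields CHW tensors in [0, 1]; plt.imshow() expects HWC, hence\n",
"# the permute(1, 2, 0) below (squeeze() only matters for 1-channel images).\n",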
"fig.suptitle('Augmented data', fontsize=16)\n", "images, _ = next(iter(train_loader)) # discard label\n", "for k in range(B) :\n", " ax = fig.add_subplot(B//5, 5, k+1)\n", "\n", " plt.imshow(images[k,:,:,:].squeeze().permute(1, 2, 0))\n", " plt.axis('off')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "ae3a2d35", "metadata": {}, "source": [ "Classical LeNet5 on CIFAR10 with data augmentation\n", "\n", "Result: \n", "[Test set] Average loss: 0.0107, Accuracy: 6196/10000 (61.96%)" ] }, { "cell_type": "code", "execution_count": 11, "id": "ea21fa03", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n" ] }, { "ename": "RuntimeError", "evalue": "CUDA error: invalid device ordinal\nCUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[0mStep\u001b[0m \u001b[1;36m3\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 79\u001b[0m '''\n\u001b[1;32m---> 80\u001b[1;33m \u001b[0mmodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mLeNet\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 81\u001b[0m \u001b[0mloss_function\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mCrossEntropyLoss\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 82\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptim\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mSGD\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1e-2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36mto\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 850\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_complex\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32melse\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 851\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 852\u001b[1;33m \u001b[1;32mreturn\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 853\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 854\u001b[0m def register_backward_hook(\n", "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_apply\u001b[1;34m(self, fn)\u001b[0m\n\u001b[0;32m 528\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 529\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 530\u001b[1;33m \u001b[0mmodule\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 531\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 532\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_apply\u001b[1;34m(self, fn)\u001b[0m\n\u001b[0;32m 528\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 529\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 530\u001b[1;33m \u001b[0mmodule\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 531\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 532\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_apply\u001b[1;34m(self, fn)\u001b[0m\n\u001b[0;32m 550\u001b[0m \u001b[1;31m# `with torch.no_grad():`\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 551\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 552\u001b[1;33m \u001b[0mparam_applied\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 553\u001b[0m \u001b[0mshould_use_set_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mparam_applied\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 554\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mshould_use_set_data\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36mconvert\u001b[1;34m(t)\u001b[0m\n\u001b[0;32m 848\u001b[0m return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,\n\u001b[0;32m 849\u001b[0m non_blocking, memory_format=convert_to_format)\n\u001b[1;32m--> 850\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_complex\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32melse\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 851\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 852\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mRuntimeError\u001b[0m: CUDA error: invalid device ordinal\nCUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1." 
] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.optim import Optimizer\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision.transforms import transforms\n", "\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "'''\n", "Step 1:\n", "'''\n", "\n", "# Image preprocessing modules\n", "transform = transforms.Compose([\n", " transforms.Pad(4),\n", " transforms.RandomHorizontalFlip(),\n", " transforms.RandomCrop(32),\n", " transforms.ToTensor()])\n", "\n", "train_dataset = datasets.CIFAR10(root='./cifar_10data/',\n", " train=True, \n", " transform=transform,\n", " download=True)\n", "\n", "test_dataset = datasets.CIFAR10(root='./cifar_10data/',\n", " train=False, \n", " transform=transforms.ToTensor())\n", "\n", "'''\n", "Step 2\n", "'''\n", "class LeNet(nn.Module) :\n", " \n", " def __init__(self) :\n", " super(LeNet, self).__init__()\n", " \n", " self.conv_layer1 = nn.Sequential(\n", " nn.Conv2d(3, 6, kernel_size=5),\n", " nn.Tanh()\n", " )\n", " self.pool_layer1 = nn.Sequential(\n", " nn.AvgPool2d(kernel_size=2, stride=2),\n", " nn.Tanh()\n", " )\n", " self.conv_layer2 = nn.Sequential(\n", " nn.Conv2d(6, 16, kernel_size=5),\n", " nn.Tanh()\n", " )\n", " self.pool_layer2 = nn.Sequential(\n", " nn.AvgPool2d(kernel_size=2, stride=2),\n", " nn.Tanh()\n", " )\n", " self.C5_layer = nn.Sequential(\n", " nn.Linear(5*5*16, 120),\n", " nn.Tanh()\n", " )\n", " self.fc_layer1 = nn.Sequential(\n", " nn.Linear(120, 84),\n", " nn.Tanh()\n", " )\n", " self.fc_layer2 = nn.Linear(84, 10)\n", " \n", " \n", " def forward(self, x) :\n", " output = self.conv_layer1(x)\n", " output = self.pool_layer1(output)\n", " output = self.conv_layer2(output)\n", " output = self.pool_layer2(output)\n", " output = output.view(-1,5*5*16)\n", " output = self.C5_layer(output)\n", " output = self.fc_layer1(output)\n", " output = self.fc_layer2(output)\n", " return output\n", "\n", " \n", "'''\n", "Step 3\n", "'''\n", "model = LeNet().to(device)\n", "loss_function = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)\n", "\n", "'''\n", "Step 4\n", "'''\n", "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=1024, shuffle=True)\n", "\n", "import time\n", "start = time.time()\n", "for epoch in range(10) :\n", " print(\"{}th epoch starting.\".format(epoch))\n", " for images, labels in train_loader :\n", " images, labels = images.to(device), labels.to(device)\n", " \n", " optimizer.zero_grad()\n", " train_loss = loss_function(model(images), labels)\n", " train_loss.backward()\n", "\n", " optimizer.step()\n", "end = time.time()\n", "print(\"Time ellapsed in training is: {}\".format(end - start))\n", "\n", "\n", "'''\n", "Step 5\n", "'''\n", "test_loss, correct, total = 0, 0, 0\n", "\n", "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1024, shuffle=False)\n", "\n", "for images, labels in test_loader :\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " output = model(images)\n", " test_loss += loss_function(output, labels).item()\n", "\n", " pred = output.max(1, keepdim=True)[1]\n", " correct += pred.eq(labels.view_as(pred)).sum().item()\n", " \n", " total += labels.size(0)\n", " \n", "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n", " test_loss /total, correct, total,\n", " 100. 
{ "cell_type": "markdown", "id": "9e708bd2", "metadata": {}, "source": [
"Modern LeNet5 on CIFAR10 with data augmentation\n",
"\n",
"Result:\n",
"[Test set] Average loss: 0.0089, Accuracy: 6885/10000 (68.85%)" ] },
{ "cell_type": "code", "execution_count": 6, "id": "4cd78b5e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"Files already downloaded and verified\n",
"Epoch 0 starting.\n",
"Epoch 1 starting.\n",
"... (epochs 2-198 elided) ...\n",
"Epoch 199 starting.\n",
"Time elapsed in training is: 2079.4606173038483\n",
"[Test set] Average loss: 0.0089, Accuracy: 6885/10000 (68.85%)\n",
"\n" ] } ], "source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.optim import Optimizer\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets\n",
"from torchvision.transforms import transforms\n",
"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"'''\n",
"Step 1\n",
"'''\n",
"\n",
"# Image preprocessing modules\n",
"transform = transforms.Compose([\n",
"    transforms.Pad(4),\n",
"    transforms.RandomHorizontalFlip(),\n",
"    transforms.RandomCrop(32),\n",
"    transforms.ToTensor()])\n",
"\n",
"train_dataset = datasets.CIFAR10(root='./cifar_10data/',\n",
"                                 train=True,\n",
"                                 transform=transform,\n",
"                                 download=True)\n",
"\n",
"test_dataset = datasets.CIFAR10(root='./cifar_10data/',\n",
"                                train=False,\n",
"                                transform=transforms.ToTensor())\n",
"\n",
"'''\n",
"Step 2\n",
"'''\n",
"class LeNet(nn.Module) :\n",
"\n",
"    def __init__(self) :\n",
"        super(LeNet, self).__init__()\n",
"\n",
"        self.conv_layer1 = nn.Sequential(\n",
"            nn.Conv2d(3, 6, kernel_size=5),\n",
"            nn.ReLU()\n",
"        )\n",
"        self.pool_layer1 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
"        self.conv_layer2 = nn.Sequential(\n",
"            nn.Conv2d(6, 16, kernel_size=5),\n",
"            nn.ReLU()\n",
"        )\n",
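"        # Shape flow for a 32x32 CIFAR10 input: conv1 -> 6x28x28,\n",
"        # pool1 -> 6x14x14, conv2 -> 16x10x10, pool2 -> 16x5x5,\n",
"        # flattened below to 5*5*16 = 400 features.\n",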
"        self.pool_layer2 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
"        self.C5_layer = nn.Sequential(\n",
"            nn.Linear(5*5*16, 120),\n",
"            nn.ReLU()\n",
"        )\n",
"        self.fc_layer1 = nn.Sequential(\n",
"            nn.Linear(120, 84),\n",
"            nn.ReLU()\n",
"        )\n",
"        self.fc_layer2 = nn.Linear(84, 10)\n",
"\n",
"\n",
"    def forward(self, x) :\n",
"        output = self.conv_layer1(x)\n",
"        output = self.pool_layer1(output)\n",
"        output = self.conv_layer2(output)\n",
"        output = self.pool_layer2(output)\n",
"        output = output.view(-1, 5*5*16)\n",
"        output = self.C5_layer(output)\n",
"        output = self.fc_layer1(output)\n",
"        output = self.fc_layer2(output)\n",
"        return output\n",
"\n",
"\n",
"'''\n",
"Step 3\n",
"'''\n",
"model = LeNet().to(device)\n",
"loss_function = torch.nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)\n",
"\n",
"'''\n",
"Step 4\n",
"'''\n",
"train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)\n",
"\n",
"import time\n",
"start = time.time()\n",
"for epoch in range(200) :\n",
"    print(\"Epoch {} starting.\".format(epoch))\n",
"    for images, labels in train_loader :\n",
"        images, labels = images.to(device), labels.to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"        train_loss = loss_function(model(images), labels)\n",
"        train_loss.backward()\n",
"\n",
"        optimizer.step()\n",
"end = time.time()\n",
"print(\"Time elapsed in training is: {}\".format(end - start))\n",
"\n",
"\n",
"'''\n",
"Step 5\n",
"'''\n",
"test_loss, correct, total = 0, 0, 0\n",
"\n",
"test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)\n",
"\n",
"for images, labels in test_loader :\n",
"    images, labels = images.to(device), labels.to(device)\n",
"\n",
"    output = model(images)\n",
"    test_loss += loss_function(output, labels).item()\n",
"\n",
"    pred = output.max(1, keepdim=True)[1]\n",
"    correct += pred.eq(labels.view_as(pred)).sum().item()\n",
"\n",
"    total += labels.size(0)\n",
"\n",
"print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n",
"    test_loss / total, correct, total,\n",
"    100. * correct / total))" ] },
{ "cell_type": "markdown", "id": "28ec9e50", "metadata": {}, "source": [
"Dropout + modern LeNet5 + data augmentation\n",
"\n",
"Results: [Test set] Average loss: 0.0091, Accuracy: 6817/10000 (68.17%)" ] },
{ "cell_type": "code", "execution_count": 1, "id": "bfed7196", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"Files already downloaded and verified\n",
"Epoch 0 starting.\n",
"Epoch 1 starting.\n",
"... (epochs 2-198 elided) ...\n",
"Epoch 199 starting.\n",
"Time elapsed in training is: 2864.05353975296\n",
"[Test set] Average loss: 0.0091, Accuracy: 6817/10000 (68.17%)\n",
"\n" ] } ], "source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.optim import Optimizer\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets\n",
"from torchvision.transforms import transforms\n",
"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"'''\n",
"Step 1:\n",
"'''\n",
"\n",
"# Image preprocessing modules\n",
"transform = transforms.Compose([\n",
"    transforms.Pad(4),\n",
"    transforms.RandomHorizontalFlip(),\n",
"    transforms.RandomCrop(32),\n",
"    transforms.ToTensor()])\n",
"\n",
"train_dataset = datasets.CIFAR10(root='./cifar_10data/',\n",
"                                 train=True,\n",
"                                 transform=transform,\n",
"                                 download=True)\n",
"\n",
"test_dataset = datasets.CIFAR10(root='./cifar_10data/',\n",
"                                train=False,\n",
"                                transform=transforms.ToTensor())\n",
"\n",
"'''\n",
"Step 2\n",
"'''\n",
"class LeNet(nn.Module) :\n",
"\n",
"    def __init__(self) :\n",
"        super(LeNet, self).__init__()\n",
"        p_drop = 0.1\n",
"\n",
"        self.conv_layer1 = nn.Sequential(\n",
"            nn.Conv2d(3, 6, kernel_size=5),\n",
"            nn.ReLU()\n",
"        )\n",
"        self.pool_layer1 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
"        self.conv_layer2 = nn.Sequential(\n",
"            nn.Conv2d(6, 16, kernel_size=5),\n",
"            nn.ReLU()\n",
"        )\n",
"        self.pool_layer2 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
"        self.C5_layer = nn.Sequential(\n",
"            nn.Dropout2d(p=p_drop),\n",
"            nn.Linear(5*5*16, 120),\n",
"            nn.ReLU()\n",
"        )\n",
"        self.fc_layer1 = nn.Sequential(\n",
"            nn.Dropout2d(p=p_drop),\n",
"            nn.Linear(120, 84),\n",
"            nn.ReLU()\n",
"        )\n",
"        self.fc_layer2 = nn.Linear(84, 10)\n",
"\n",
"\n",
"    def forward(self, x) :\n",
"        output = self.conv_layer1(x)\n",
"        output = self.pool_layer1(output)\n",
"        output = self.conv_layer2(output)\n",
"        output = self.pool_layer2(output)\n",
"        output = output.view(-1, 5*5*16)\n",
"        output = self.C5_layer(output)\n",
"        output = self.fc_layer1(output)\n",
"        output = self.fc_layer2(output)\n",
"        return output\n",
"\n",
"'''\n",
"Step 3\n",
"'''\n",
"model = LeNet().to(device)\n",
"loss_function = torch.nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
"\n",
"'''\n",
"Step 4\n",
"'''\n",
"model.train()  # enable dropout during training\n",
"train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)\n",
"\n",
"import time\n",
"start = time.time()\n",
"for epoch in range(200) :\n",
"    print(\"Epoch {} starting.\".format(epoch))\n",
"    for i, (images, labels) in enumerate(train_loader) :\n",
"        images, labels = images.to(device), labels.to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"        train_loss = loss_function(model(images), labels)\n",
"        train_loss.backward()\n",
"\n",
"        optimizer.step()\n",
"\n",
"end = time.time()\n",
"print(\"Time elapsed in training is: {}\".format(end - start))\n",
"\n",
"\n",
"'''\n",
"Step 5\n",
"'''\n",
"model.eval()  # disable dropout for evaluation\n",
"test_loss, correct, total = 0, 0, 0\n",
"\n",
"test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)\n",
"\n",
"for images, labels in test_loader :\n",
"    images, labels = images.to(device), labels.to(device)\n",
"\n",
"    output = model(images)\n",
"    test_loss += loss_function(output, labels).item()\n",
"\n",
"    pred = output.max(1, keepdim=True)[1]\n",
"    correct += pred.eq(labels.view_as(pred)).sum().item()\n",
"\n",
"    total += labels.size(0)\n",
"\n",
"print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n",
"    test_loss / total, correct, total,\n",
"    100. * correct / total))" ] },
{ "cell_type": "markdown", "id": "e2027950", "metadata": {}, "source": [
"Weight decay + modern LeNet5 + data augmentation\n",
"\n",
"Result: [Test set] Average loss: 0.0084, Accuracy: 7041/10000 (70.41%)" ] },
{ "cell_type": "code", "execution_count": 2, "id": "95e81493", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"Files already downloaded and verified\n",
"Epoch 0 starting.\n",
"Epoch 1 starting.\n",
"... (epochs 2-198 elided) ...\n",
"Epoch 199 starting.\n",
"Time elapsed in training is: 2842.564428806305\n",
"[Test set] Average loss: 0.0084, Accuracy: 7041/10000 (70.41%)\n",
"\n" ] } ], "source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.optim import Optimizer\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets\n",
"from torchvision.transforms import transforms\n",
"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"'''\n",
"Step 1:\n",
"'''\n",
"\n",
"# Image preprocessing modules\n",
"transform = transforms.Compose([\n",
"    transforms.Pad(4),\n",
"    transforms.RandomHorizontalFlip(),\n",
"    transforms.RandomCrop(32),\n",
"    transforms.ToTensor()])\n",
"\n",
"train_dataset = datasets.CIFAR10(root='./cifar_10data/',\n",
"                                 train=True,\n",
"                                 transform=transform,\n",
"                                 download=True)\n",
"\n",
"test_dataset = datasets.CIFAR10(root='./cifar_10data/',\n",
"                                train=False,\n",
"                                transform=transforms.ToTensor())\n",
"\n",
"'''\n",
"Step 2\n",
"'''\n",
"class LeNet(nn.Module) :\n",
"\n",
"    def __init__(self) :\n",
"        super(LeNet, self).__init__()\n",
"\n",
"        self.conv_layer1 = nn.Sequential(\n",
"            nn.Conv2d(3, 6, kernel_size=5),\n",
"            nn.ReLU()\n",
"        )\n",
"        self.pool_layer1 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
"        self.conv_layer2 = nn.Sequential(\n",
"            nn.Conv2d(6, 16, kernel_size=5),\n",
"            nn.ReLU()\n",
"        )\n",
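"        # weight_decay=1e-05 in Step 3 below adds an L2 penalty, i.e. each\n",
"        # SGD step becomes w <- w - lr * (grad + 1e-05 * w).\n",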
"        self.pool_layer2 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
"        self.C5_layer = nn.Sequential(\n",
"            nn.Linear(5*5*16, 120),\n",
"            nn.ReLU()\n",
"        )\n",
"        self.fc_layer1 = nn.Sequential(\n",
"            nn.Linear(120, 84),\n",
"            nn.ReLU()\n",
"        )\n",
"        self.fc_layer2 = nn.Linear(84, 10)\n",
"\n",
"\n",
"    def forward(self, x) :\n",
"        output = self.conv_layer1(x)\n",
"        output = self.pool_layer1(output)\n",
"        output = self.conv_layer2(output)\n",
"        output = self.pool_layer2(output)\n",
"        output = output.view(-1, 5*5*16)\n",
"        output = self.C5_layer(output)\n",
"        output = self.fc_layer1(output)\n",
"        output = self.fc_layer2(output)\n",
"        return output\n",
"\n",
"'''\n",
"Step 3\n",
"'''\n",
"model = LeNet().to(device)\n",
"loss_function = torch.nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-05)\n",
"\n",
"'''\n",
"Step 4\n",
"'''\n",
"train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)\n",
"\n",
"import time\n",
"start = time.time()\n",
"for epoch in range(200) :\n",
"    print(\"Epoch {} starting.\".format(epoch))\n",
"    for i, (images, labels) in enumerate(train_loader) :\n",
"        images, labels = images.to(device), labels.to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"        train_loss = loss_function(model(images), labels)\n",
"        train_loss.backward()\n",
"\n",
"        optimizer.step()\n",
"\n",
"end = time.time()\n",
"print(\"Time elapsed in training is: {}\".format(end - start))\n",
"\n",
"\n",
"'''\n",
"Step 5\n",
"'''\n",
"test_loss, correct, total = 0, 0, 0\n",
"\n",
"test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)\n",
"\n",
"for images, labels in test_loader :\n",
"    images, labels = images.to(device), labels.to(device)\n",
"\n",
"    output = model(images)\n",
"    test_loss += loss_function(output, labels).item()\n",
"\n",
"    pred = output.max(1, keepdim=True)[1]\n",
"    correct += pred.eq(labels.view_as(pred)).sum().item()\n",
"\n",
"    total += labels.size(0)\n",
"\n",
"print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n",
"    test_loss / total, correct, total,\n",
"    100. * correct / total))" ] },
{ "cell_type": "markdown", "id": "fd340214", "metadata": {}, "source": [
"Weight decay + dropout + modern LeNet5 + data augmentation\n",
"\n",
"Results: [Test set] Average loss: 0.0090, Accuracy: 6893/10000 (68.93%)" ] },
{ "cell_type": "code", "execution_count": 3, "id": "a857d0db", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"Files already downloaded and verified\n",
"Epoch 0 starting.\n",
"Epoch 1 starting.\n",
"... (epochs 2-198 elided) ...\n",
"Epoch 199 starting.\n",
"Time elapsed in training is: 2347.501813650131\n",
"[Test set] Average loss: 0.0090, Accuracy: 6893/10000 (68.93%)\n",
"\n" ] } ], "source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.optim import Optimizer\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets\n",
"from torchvision.transforms import transforms\n",
"\n",
"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"'''\n",
"Step 1:\n",
"'''\n",
"\n",
"# Image preprocessing modules\n",
"transform = transforms.Compose([\n",
"    transforms.Pad(4),\n",
"    transforms.RandomHorizontalFlip(),\n",
"    transforms.RandomCrop(32),\n",
"    transforms.ToTensor()])\n",
"\n",
"train_dataset = datasets.CIFAR10(root='./cifar_10data/',\n",
"                                 train=True,\n",
"                                 transform=transform,\n",
"                                 download=True)\n",
"\n",
"test_dataset = datasets.CIFAR10(root='./cifar_10data/',\n",
"                                train=False,\n",
"                                transform=transforms.ToTensor())\n",
"\n",
"\n",
"'''\n",
"Step 2\n",
"'''\n",
"class LeNet(nn.Module) :\n",
"\n",
"    def __init__(self) :\n",
"        super(LeNet, self).__init__()\n",
"        p_drop = 0.1\n",
"\n",
"        self.conv_layer1 = nn.Sequential(\n",
"            nn.Conv2d(3, 6, kernel_size=5),\n",
"            nn.ReLU()\n",
"        )\n",
"        self.pool_layer1 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
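"        # Both regularizers are active in this cell: dropout (p_drop=0.1) on\n",
"        # the fully connected layers and weight_decay=1e-05 in the optimizer.\n",
"        # In the recorded runs this combination (68.93%) did slightly worse\n",
"        # than weight decay alone (70.41%).\n",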
" nn.ReLU()\n", " )\n", " self.pool_layer2 = nn.MaxPool2d(kernel_size=2, stride=2) \n", " self.C5_layer = nn.Sequential(\n", " nn.Dropout2d(p=p_drop),\n", " nn.Linear(5*5*16, 120),\n", " nn.ReLU()\n", " )\n", " self.fc_layer1 = nn.Sequential(\n", " nn.Dropout2d(p=p_drop),\n", " nn.Linear(120, 84),\n", " nn.ReLU()\n", " )\n", " self.fc_layer2 = nn.Linear(84, 10)\n", "\n", "\n", " def forward(self, x) :\n", " output = self.conv_layer1(x)\n", " output = self.pool_layer1(output)\n", " output = self.conv_layer2(output)\n", " output = self.pool_layer2(output)\n", " output = output.view(-1,5*5*16)\n", " output = self.C5_layer(output)\n", " output = self.fc_layer1(output)\n", " output = self.fc_layer2(output)\n", " return output\n", "\n", "'''\n", "Step 3\n", "'''\n", "model = LeNet().to(device)\n", "loss_function = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-05)\n", "\n", "'''\n", "Step 4\n", "'''\n", "model.train()\n", "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)\n", "\n", "import time\n", "start = time.time()\n", "for epoch in range(200) :\n", " print(\"{}th epoch starting.\".format(epoch))\n", " for i, (images, labels) in enumerate(train_loader) :\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " optimizer.zero_grad()\n", " train_loss = loss_function(model(images), labels)\n", " train_loss.backward()\n", "\n", " optimizer.step()\n", "\n", "end = time.time()\n", "print(\"Time ellapsed in training is: {}\".format(end - start))\n", "\n", "\n", "'''\n", "Step 5\n", "'''\n", "model.eval()\n", "test_loss, correct, total = 0, 0, 0\n", "\n", "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)\n", "\n", "for images, labels in test_loader :\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " output = model(images)\n", " test_loss += loss_function(output, labels).item()\n", "\n", " pred = output.max(1, keepdim=True)[1]\n", " correct += pred.eq(labels.view_as(pred)).sum().item()\n", "\n", " total += labels.size(0)\n", "\n", "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n", " test_loss /total, correct, total,\n", " 100. 
{ "cell_type": "markdown", "id": "6bc1c164", "metadata": {}, "source": [ "AlexNet CIFAR10\n", "\n", "Results: [Test set] Average loss: 0.0046, Accuracy: 8852/10000 (88.52%)" ] }, { "cell_type": "code", "execution_count": 4, "id": "54faac92", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "0th epoch starting.\n", "Epoch [1] Loss: 2.2951\n", "... (epoch logs elided for brevity) ...\n", "99th epoch starting.\n", "Epoch [100] Loss: 0.0976\n", "Time ellapsed in training is: 2575.2117249965668\n", "[Test set] Average loss: 0.0046, Accuracy: 8852/10000 (88.52%)\n", "\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.optim import Optimizer\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision.transforms import transforms\n", "\n", "\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "'''\n", "Step 1:\n", "'''\n", "\n", "transform = transforms.Compose([\n", " transforms.Pad(4),\n", " transforms.RandomHorizontalFlip(),\n", " transforms.RandomCrop(32),\n", " transforms.ToTensor()])\n", "\n", "train_dataset = datasets.CIFAR10(root='./cifar_10data/',\n", " train=True, \n", " transform=transform,\n", " download=True)\n", "\n", "test_dataset = datasets.CIFAR10(root='./cifar_10data/',\n", " train=False, \n", " transform=transforms.ToTensor())\n", " \n", "\n", "\n", "'''\n", "Step 2\n", "'''\n", "\n", "class AlexNet(nn.Module) :\n", " \n", " def __init__(self, num_class=10) :\n", " super(AlexNet, self).__init__()\n", " \n", " self.conv_layer1 = nn.Sequential(\n", " nn.Conv2d(3, 96, kernel_size=4),\n", " nn.ReLU(),\n", " nn.Conv2d(96, 96, kernel_size=3),\n",
" nn.ReLU()\n", " )\n", " self.conv_layer2 = nn.Sequential(\n", " nn.Conv2d(96, 256, kernel_size=5, padding=2),\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=3, stride=2)\n", " )\n", " self.conv_layer3 = nn.Sequential(\n", " nn.Conv2d(256, 384, kernel_size=3, padding=1),\n", " nn.ReLU(),\n", " nn.Conv2d(384, 384, kernel_size=3, padding=1),\n", " nn.ReLU(),\n", " nn.Conv2d(384, 256, kernel_size=3, padding=1),\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=3, stride=2)\n", " )\n", " \n", " self.fc_layer1 = nn.Sequential(\n", " nn.Dropout(), # p=0.5 by default\n", " nn.Linear(9216, 4096),\n", " nn.ReLU(),\n", " nn.Dropout(), # p=0.5 by default\n", " nn.Linear(4096, 4096),\n", " nn.ReLU(),\n", " nn.Linear(4096, num_class)\n", " )\n", " \n", " def forward(self, x) :\n", " output = self.conv_layer1(x)\n", " output = self.conv_layer2(output)\n", " output = self.conv_layer3(output)\n", " output = output.view(-1, 9216)\n", " output = self.fc_layer1(output)\n", " return output\n", "\n", " \n", "\n", "'''\n", "Step 3\n", "'''\n", "model = AlexNet().to(device)\n", "loss_function = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=1e-1, weight_decay=0.00005)\n", "\n", "'''\n", "Step 4\n", "'''\n", "model.train()\n", "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)\n", "\n", "import time\n", "start = time.time()\n", "for epoch in range(100) :\n", " print(\"{}th epoch starting.\".format(epoch))\n", " for i, (images, labels) in enumerate(train_loader) :\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " optimizer.zero_grad()\n", " train_loss = loss_function(model(images), labels)\n", " train_loss.backward()\n", "\n", " optimizer.step()\n", "\n", " print(\"Epoch [{}] Loss: {:.4f}\".format(epoch+1, train_loss.item()))\n", "\n", "end = time.time()\n", "print(\"Time elapsed in training is: {}\".format(end - start))\n", "\n", "\n", "'''\n", "Step 5\n", "'''\n", "model.eval()\n", "test_loss, correct, total = 0, 0, 0\n", "\n", "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)\n", "with torch.no_grad(): #using context manager\n", " for images, labels in test_loader :\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " output = model(images)\n", " test_loss += loss_function(output, labels).item()\n", "\n", " pred = output.max(1, keepdim=True)[1]\n", " correct += pred.eq(labels.view_as(pred)).sum().item()\n", "\n", " total += labels.size(0)\n", "\n", "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n", " test_loss /total, correct, total,\n", " 100. * correct / total))" ] },
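{ "cell_type": "markdown", "id": "c5d73e2b", "metadata": {}, "source": [ "Added sketch (not from the original notebook): counting trainable parameters; for this AlexNet variant the two 4096-wide fully connected layers dominate the total." ] }, { "cell_type": "code", "execution_count": null, "id": "d6c64f3c", "metadata": {}, "outputs": [], "source": [ "# Illustrative helper: total number of trainable parameters of a model.\n", "def count_parameters(model):\n", " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "\n", "print(count_parameters(AlexNet())) # AlexNet is defined in the cell above" ] },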
{ "cell_type": "markdown", "id": "5d428b24", "metadata": {}, "source": [ "VGG13 for CIFAR10\n", "\n", "Results: [Test set] Average loss: 0.0041, Accuracy: 8624/10000 (86.24%)\n", "\n", "[Test set] Average loss: 0.0041, Accuracy: 8735/10000 (87.35%)" ] }, { "cell_type": "code", "execution_count": 1, "id": "605908c8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "0th epoch starting.\n", "Epoch [1] Loss: 1.8598\n", "... (epoch logs elided for brevity) ...\n", "199th epoch starting.\n", "Epoch [200] Loss: 0.1592\n", "Time ellapsed in training is: 2501.492520093918\n", "[Test set] Average loss: 0.0049, Accuracy: 8548/10000 (85.48%)\n", "\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.optim import Optimizer\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision.transforms import transforms\n", "\n", "\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "'''\n", "Step 1:\n", "'''\n", "\n", "# Image preprocessing modules\n", "transform = transforms.Compose([\n", " transforms.Pad(4),\n", " transforms.RandomHorizontalFlip(),\n", " transforms.RandomCrop(32),\n", " transforms.ToTensor()])\n", "\n", "train_dataset = datasets.CIFAR10(root='./cifar_data/',\n", " train=True, \n", " transform=transform,\n", " download=True)\n", "\n", "test_dataset = datasets.CIFAR10(root='./cifar_data/',\n", " train=False, \n", " transform=transforms.ToTensor())\n", " \n", "'''\n", "Step 2\n", "'''\n", "\n", "class VGG13(nn.Module) :\n", " def __init__(self) :\n", " super(VGG13, self).__init__()\n", " \n", " self.conv_layer1 = nn.Sequential(\n", " nn.Conv2d(3, 64, kernel_size=3, padding=1), # 64 * 32 * 32\n", " nn.ReLU(),\n", " nn.Conv2d(64, 64, kernel_size=3, padding=1), # 64 * 32 * 32\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=2, stride=2) # 64 * 16 * 16\n", " )\n", " self.conv_layer2 = nn.Sequential(\n", " nn.Conv2d(64, 128, kernel_size=3, padding=1), # 128 * 16 * 16\n", " nn.ReLU(),\n", " nn.Conv2d(128, 128, kernel_size=3, padding=1), # 128 * 16 * 16\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=2, stride=2) # 128 * 8 * 8\n", " )\n", " self.conv_layer3 = nn.Sequential(\n", " nn.Conv2d(128, 256, kernel_size=3, padding=1), # 256 * 8 * 8\n", " nn.ReLU(),\n", " nn.Conv2d(256, 256, kernel_size=3, padding=1), # 256 * 8 * 8\n", " nn.ReLU(),\n", " nn.Conv2d(256, 256, kernel_size=3, padding=1), # 256 * 8 * 8\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=2, stride=2) # 256 * 4 * 4\n", " )\n", " self.conv_layer4 = nn.Sequential(\n", " nn.Conv2d(256, 512, kernel_size=3, padding=1), # 512 * 4 * 4\n", " nn.ReLU(),\n", " nn.Conv2d(512, 512, kernel_size=3, padding=1), # 512 * 4 * 4\n",
4\n", " nn.ReLU(),\n", " nn.Conv2d(512, 512, kernel_size=3, padding=1), # 512 * 4 * 4\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=2, stride=2), # 512 * 2 * 2\n", " )\n", " self.fc_layer1 = nn.Sequential(\n", " nn.Linear(512*2*2, 4096), # 1 * 4096\n", " nn.ReLU(),\n", " nn.Dropout()\n", " )\n", " self.fc_layer2 = nn.Sequential(\n", " nn.Linear(4096, 4096), # 1 * 4096\n", " nn.ReLU(),\n", " nn.Dropout()\n", " )\n", " self.fc_layer3 = nn.Sequential(\n", " nn.Linear(4096, 10), # 1 * num_class\n", " )\n", "\n", " \n", " def forward(self, x) :\n", " output = self.conv_layer1(x)\n", " output = self.conv_layer2(output)\n", " output = self.conv_layer3(output)\n", " output = self.conv_layer4(output)\n", " output = output.view(-1, 512*2*2)\n", " output = self.fc_layer1(output)\n", " output = self.fc_layer2(output)\n", " output = self.fc_layer3(output)\n", " return output\n", "\n", "\n", "'''\n", "Step 3 \n", "'''\n", "model = VGG13().to(device)\n", "\n", "def weights_init(m):\n", " if isinstance(m, nn.Conv2d):\n", " nn.init.kaiming_normal(m.weight.data)\n", " m.bias.data.zero_()\n", "model.apply(weights_init)\n", "\n", "loss_function = torch.nn.CrossEntropyLoss()\n", "\n", "optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)\n", "\n", "'''\n", "Step 4\n", "'''\n", "model.train()\n", "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)\n", "\n", "import time\n", "start = time.time()\n", "for epoch in range(200) :\n", " print(\"{}th epoch starting.\".format(epoch))\n", " for i, (images, labels) in enumerate(train_loader) :\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " optimizer.zero_grad()\n", " train_loss = loss_function(model(images), labels)\n", " train_loss.backward()\n", "\n", " optimizer.step()\n", "\n", " print (\"Epoch [{}] Loss: {:.4f}\".format(epoch+1, train_loss.item()))\n", "\n", "end = time.time()\n", "print(\"Time ellapsed in training is: {}\".format(end - start))\n", "\n", "\n", "'''\n", "Step 5\n", "'''\n", "model.eval()\n", "test_loss, correct, total = 0, 0, 0\n", "\n", "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)\n", "with torch.no_grad():\n", " for images, labels in test_loader :\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " output = model(images)\n", " test_loss += loss_function(output, labels).item()\n", "\n", " pred = output.max(1, keepdim=True)[1]\n", " correct += pred.eq(labels.view_as(pred)).sum().item()\n", "\n", " total += labels.size(0)\n", "\n", "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n", " test_loss /total, correct, total,\n", " 100. 
{ "cell_type": "markdown", "id": "4a43fbe3", "metadata": {}, "source": [ "# Autodiff example" ] }, { "cell_type": "code", "execution_count": 6, "id": "e7098482", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(3.3000, requires_grad=True)\n", "tensor(1.1000, requires_grad=True)\n", "tensor(2.3000, requires_grad=True)\n", "0.2924030125141144\n", "tensor(-0.2853)\n", "tensor(-1.0158)\n", "tensor(0.2880)\n" ] } ], "source": [ "import torch\n", "\n", "x = torch.tensor(3.3, requires_grad=True)\n", "y = torch.tensor(1.1, requires_grad=True)\n", "z = torch.tensor(2.3, requires_grad=True)\n", "\n", "fn = torch.sin(torch.cosh(y*y+x/z)+torch.tanh(x*y*z))/torch.log(1+torch.exp(x))\n", "fn.backward()\n", "\n", "print(x)\n", "print(y)\n", "print(z)\n", "print(fn.item())\n", "print(x.grad)\n", "print(y.grad)\n", "print(z.grad)" ] }, { "cell_type": "markdown", "id": "c8904843", "metadata": {}, "source": [ "NiN Network for CIFAR10\n", "\n", "Results: [Test set] Average loss: 0.0045, Accuracy: 8484/10000 (84.84%)" ] }, { "cell_type": "code", "execution_count": 2, "id": "0987afc8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "0th epoch starting.\n", "Epoch [1] Loss: 1.8504\n", "... (epoch logs elided for brevity) ...\n", "99th epoch starting.\n", "Epoch [100] Loss: 0.3257\n", "Time ellapsed in training is: 1114.2113692760468\n", "[Test set] Average loss: 0.0045, Accuracy: 8484/10000 (84.84%)\n", "\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.optim import Optimizer\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision.transforms import transforms\n",
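"# (added note) NiN replaces large fully connected layers with 1x1 convolutions\n", "# ('mlpconv') and a final global average pooling over per-class score maps.\n",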
"\n", "\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "'''\n", "Step 1:\n", "'''\n", "\n", "transform = transforms.Compose([\n", " transforms.Pad(4),\n", " transforms.RandomHorizontalFlip(),\n", " transforms.RandomCrop(32),\n", " transforms.ToTensor()])\n", "\n", "train_dataset = datasets.CIFAR10(root='./cifar_10data/',\n", " train=True, \n", " transform=transform,\n", " download=True)\n", "\n", "test_dataset = datasets.CIFAR10(root='./cifar_10data/',\n", " train=False, \n", " transform=transforms.ToTensor())\n", " \n", "'''\n", "Step 2\n", "'''\n", "class NiN(nn.Module) :\n", " def __init__(self) :\n", " super(NiN, self).__init__()\n", " \n", " self.mlpconv_layer1 = nn.Sequential(\n", " nn.Conv2d(3, 192, kernel_size=5, padding=2), # 192 * 32 * 32\n", " nn.ReLU(),\n", " nn.Conv2d(192, 160, kernel_size=1), # 160 * 32 * 32\n", " nn.ReLU(),\n", " nn.Conv2d(160, 96, kernel_size=1), # 96 * 32 * 32\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1),# 96 * 16 * 16\n", " nn.Dropout()\n", " )\n", " self.mlpconv_layer2 = nn.Sequential(\n", " nn.Conv2d(96, 192, kernel_size=5, padding=2), # 192 * 16 * 16\n", " nn.ReLU(),\n", " nn.Conv2d(192, 192, kernel_size=1), # 192 * 16 * 16\n", " nn.ReLU(),\n", " nn.Conv2d(192, 192, kernel_size=1), # 192 * 16 * 16\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1),# 192 * 8 * 8\n", " nn.Dropout()\n", " )\n", " self.mlpconv_layer3 = nn.Sequential(\n", " nn.Conv2d(192, 192, kernel_size=3, padding=1), # 192 * 8 * 8\n", " nn.ReLU(),\n", " nn.Conv2d(192, 192, kernel_size=1), # 192 * 8 * 8\n", " nn.ReLU(),\n", " nn.Conv2d(192, 10, kernel_size=1), # 10 * 8 * 8\n", " nn.ReLU(),\n", " nn.AvgPool2d(kernel_size=8) # 10 * 1 * 1\n", " )\n", " \n", " \n", " def forward(self, x) :\n", " output = self.mlpconv_layer1(x)\n", " output = self.mlpconv_layer2(output)\n", " output = self.mlpconv_layer3(output)\n", " output = output.view(-1, 10)\n", " return output\n", "\n", "\n", "'''\n", "Step 3\n", "'''\n", "model = NiN().to(device)\n", "loss_function = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, weight_decay=0.00001)\n", "\n", "\n", "'''\n", "Step 4\n", "'''\n", "model.train()\n", "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)\n", "\n", "import time\n", "start = time.time()\n", "for epoch in range(100) :\n", " print(\"{}th epoch starting.\".format(epoch))\n", " for i, (images, labels) in enumerate(train_loader) :\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " optimizer.zero_grad()\n", " train_loss = loss_function(model(images), labels)\n", " train_loss.backward()\n", "\n", " optimizer.step()\n", "\n", " print(\"Epoch [{}] Loss: {:.4f}\".format(epoch+1, train_loss.item()))\n", "\n", "end = time.time()\n", "print(\"Time elapsed in training is: {}\".format(end - start))\n", "\n", "\n", "'''\n", "Step 5\n", "'''\n", "model.eval()\n", "test_loss, correct, total = 0, 0, 0\n", "\n", "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)\n", "with torch.no_grad():\n", " for images, labels in test_loader :\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " output = model(images)\n", " test_loss += loss_function(output, labels).item()\n", "\n", " pred = output.max(1, keepdim=True)[1]\n", " correct += pred.eq(labels.view_as(pred)).sum().item()\n", "\n", " total += labels.size(0)\n", "\n", "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n", " test_loss /total, correct, total,\n", " 100. * correct / total))" ] },
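{ "cell_type": "markdown", "id": "1c9e5f7a", "metadata": {}, "source": [ "Added sketch (not from the original notebook): the effect of the final AvgPool2d in NiN, which turns the 10 score maps of size 8x8 into one score per class, so no fully connected classifier is needed." ] }, { "cell_type": "code", "execution_count": null, "id": "2d8f6a1b", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "# Global average pooling over 8x8 maps: each of the 10 channels collapses\n", "# to a single class score.\n", "x = torch.randn(2, 10, 8, 8) # batch of 2, 10 score maps\n", "out = nn.AvgPool2d(kernel_size=8)(x).view(-1, 10)\n", "print(out.shape) # expected: torch.Size([2, 10])" ] },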
{ "cell_type": "markdown", "id": "f0fb1a24", "metadata": {}, "source": [ "GoogLeNet for CIFAR10\n", "\n", "Results: [Test set] Average loss: 0.0045, Accuracy: 8728/10000 (87.28%)\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "17b81100", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "0th epoch starting.\n", "Epoch [1] Loss: 2.2639\n", "... (epoch logs elided for brevity) ...\n", "99th epoch starting.\n", "Epoch [100] Loss: 0.1732\n", "Time ellapsed in training is: 7174.51455116272\n", "[Test set] Average loss: 0.0045, Accuracy: 8728/10000 (87.28%)\n", "\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.optim import Optimizer\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision.transforms import transforms\n", "\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "'''\n", "Step 1:\n", "'''\n", "transform = transforms.Compose([\n", " transforms.Pad(4),\n", " transforms.RandomHorizontalFlip(),\n", " transforms.RandomCrop(32),\n", " transforms.ToTensor()])\n", "\n", "train_dataset = datasets.CIFAR10(root='./cifar_10data/',\n", " train=True, \n", " transform=transform,\n", " download=True)\n", "\n", "test_dataset = datasets.CIFAR10(root='./cifar_10data/',\n", " train=False, \n", " transform=transforms.ToTensor())\n", " \n", "'''\n", "Step 2:\n", "'''\n", "class GoogLeNet(nn.Module):\n", "\n", " def __init__(self):\n", " super(GoogLeNet, self).__init__() \n", "\n", " self.conv1 = BasicConv2d(3, 64, 
kernel_size=7, padding=3)\n", " self.conv2 = BasicConv2d(64, 64, kernel_size=1)\n", " self.conv3 = BasicConv2d(64, 192, kernel_size=5)\n", "\n", " self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)\n", " self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)\n", " self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)\n", "\n", " self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)\n", " self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)\n", " self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)\n", " self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)\n", " self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)\n", " self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n", "\n", " self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)\n", " self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)\n", "\n", "\n", " self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n", " self.dropout = nn.Dropout(0.4)\n", " self.fc = nn.Linear(1024, 10)\n", " \n", " for m in self.modules():\n", " if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):\n", " torch.nn.init.xavier_uniform_(m.weight)\n", "\n", " def forward(self, x):\n", " # N x 3 x 32 x 32\n", " x = self.conv1(x)\n", " # N x 64 x 32 x 32\n", " x = self.conv2(x)\n", " # N x 64 x 32 x 32\n", " x = self.conv3(x)\n", " # N x 192 x 28 x 28\n", "\n", " # N x 192 x 28 x 28\n", " x = self.inception3a(x)\n", " # N x 256 x 28 x 28\n", " x = self.inception3b(x)\n", " # N x 480 x 28 x 28\n", " x = self.maxpool3(x)\n", " # N x 480 x 14 x 14\n", " x = self.inception4a(x)\n", " # N x 512 x 14 x 14\n", " x = self.inception4b(x)\n", " # N x 512 x 14 x 14\n", " x = self.inception4c(x)\n", " # N x 512 x 14 x 14\n", " x = self.inception4d(x)\n", " # N x 528 x 14 x 14\n", " x = self.inception4e(x)\n", " # N x 832 x 14 x 14\n", " x = self.maxpool4(x)\n", " # N x 832 x 7 x 7\n", " x = self.inception5a(x)\n", " # N x 832 x 7 x 7\n", " x = self.inception5b(x)\n", " # N x 1024 x 7 x 7\n", " x = self.avgpool(x)\n", " # N x 1024 x 1 x 1\n", " x = torch.flatten(x, 1)\n", " # N x 1024\n", " x = self.dropout(x)\n", " x = self.fc(x)\n", " # N x 10 (num_classes)\n", " return x\n", "\n", "class Inception(nn.Module):\n", "\n", " def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):\n", " super(Inception, self).__init__()\n", " \n", " self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)\n", "\n", " self.branch2 = nn.Sequential(\n", " BasicConv2d(in_channels, ch3x3red, kernel_size=1),\n", " BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)\n", " )\n", "\n", " self.branch3 = nn.Sequential(\n", " BasicConv2d(in_channels, ch5x5red, kernel_size=1),\n", " BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)\n", " )\n", "\n", " self.branch4 = nn.Sequential(\n", " nn.MaxPool2d(kernel_size=3, stride=1, padding=1),\n", " BasicConv2d(in_channels, pool_proj, kernel_size=1)\n", " )\n", "\n", " def forward(self, x):\n", " branch1 = self.branch1(x)\n", " branch2 = self.branch2(x)\n", " branch3 = self.branch3(x)\n", " branch4 = self.branch4(x)\n", "\n", " return torch.cat([branch1, branch2, branch3, branch4], 1)\n", "\n", "\n", "class BasicConv2d(nn.Module):\n", " def __init__(self, in_channels, out_channels, **kwargs):\n", " super(BasicConv2d, self).__init__()\n", " self.conv = nn.Sequential(\n", " nn.Conv2d(in_channels, out_channels, **kwargs),\n", " nn.ReLU()\n", " )\n", " def forward(self, x):\n", " return self.conv(x)\n", " \n", "\n", "\n", "'''\n", "Step 3\n", "'''\n", 
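"# (added note) Each Inception block concatenates its four branches, so its output\n", "# channel count is ch1x1 + ch3x3 + ch5x5 + pool_proj; e.g. inception3a yields\n", "# 64 + 128 + 32 + 32 = 256 channels, matching inception3b's input size.\n",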
"model = GoogLeNet().to(device)\n", "loss_function = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=5e-4)\n", "\n", "'''\n", "Step 4\n", "'''\n", "model.train()\n", "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)\n", "\n", "import time\n", "start = time.time()\n", "for epoch in range(100) :\n", " print(\"{}th epoch starting.\".format(epoch))\n", " for i, (images, labels) in enumerate(train_loader) :\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " optimizer.zero_grad()\n", " train_loss = loss_function(model(images), labels)\n", " train_loss.backward()\n", "\n", " optimizer.step()\n", "\n", " print (\"Epoch [{}] Loss: {:.4f}\".format(epoch+1, train_loss.item()))\n", "\n", "end = time.time()\n", "print(\"Time ellapsed in training is: {}\".format(end - start))\n", "\n", "\n", "'''\n", "Step 5\n", "'''\n", "model.eval()\n", "test_loss, correct, total = 0, 0, 0\n", "\n", "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)\n", "with torch.no_grad(): #using context manager\n", " for images, labels in test_loader :\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " output = model(images)\n", " test_loss += loss_function(output, labels).item()\n", "\n", " pred = output.max(1, keepdim=True)[1]\n", " correct += pred.eq(labels.view_as(pred)).sum().item()\n", "\n", " total += labels.size(0)\n", "\n", "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n", " test_loss /total, correct, total,\n", " 100. * correct / total))" ] }, { "cell_type": "markdown", "id": "1e8f540f", "metadata": {}, "source": [ "GoogLeNet with BatchNorm CIFAR10\n", "\n", "\n", "Results: [Test set] Average loss: 0.0072, Accuracy: 8222/10000 (82.22%)" ] }, { "cell_type": "code", "execution_count": 1, "id": "c393ec38", "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "0th epoch starting.\n", "Epoch [1] Loss: 1.3556\n", "1th epoch starting.\n", "Epoch [2] Loss: 1.0016\n", "2th epoch starting.\n", "Epoch [3] Loss: 0.8536\n", "3th epoch starting.\n", "Epoch [4] Loss: 0.6516\n", "4th epoch starting.\n", "Epoch [5] Loss: 0.4868\n", "5th epoch starting.\n", "Epoch [6] Loss: 0.3415\n", "6th epoch starting.\n", "Epoch [7] Loss: 0.3197\n", "7th epoch starting.\n", "Epoch [8] Loss: 0.3560\n", "8th epoch starting.\n", "Epoch [9] Loss: 0.3828\n", "9th epoch starting.\n", "Epoch [10] Loss: 0.4011\n", "10th epoch starting.\n", "Epoch [11] Loss: 0.3260\n", "11th epoch starting.\n", "Epoch [12] Loss: 0.2618\n", "12th epoch starting.\n", "Epoch [13] Loss: 0.2623\n", "13th epoch starting.\n", "Epoch [14] Loss: 0.2755\n", "14th epoch starting.\n", "Epoch [15] Loss: 0.3459\n", "15th epoch starting.\n", "Epoch [16] Loss: 0.2778\n", "16th epoch starting.\n", "Epoch [17] Loss: 0.2131\n", "17th epoch starting.\n", "Epoch [18] Loss: 0.3055\n", "18th epoch starting.\n", "Epoch [19] Loss: 0.1537\n", "19th epoch starting.\n", "Epoch [20] Loss: 0.0999\n", "20th epoch starting.\n", "Epoch [21] Loss: 0.2503\n", "21th epoch starting.\n", "Epoch [22] Loss: 0.1957\n", "22th epoch starting.\n", "Epoch [23] Loss: 0.1567\n", "23th epoch starting.\n", "Epoch [24] Loss: 0.2063\n", "24th epoch starting.\n", "Epoch [25] Loss: 0.0958\n", "25th epoch starting.\n", "Epoch [26] Loss: 0.0811\n", "26th epoch starting.\n", "Epoch [27] Loss: 0.1389\n", "27th epoch 
starting.\n", "Epoch [28] Loss: 0.1160\n", "28th epoch starting.\n", "Epoch [29] Loss: 0.2026\n", "29th epoch starting.\n", "Epoch [30] Loss: 0.1244\n", "30th epoch starting.\n", "Epoch [31] Loss: 0.0889\n", "31th epoch starting.\n", "Epoch [32] Loss: 0.0641\n", "32th epoch starting.\n", "Epoch [33] Loss: 0.1299\n", "33th epoch starting.\n", "Epoch [34] Loss: 0.1237\n", "34th epoch starting.\n", "Epoch [35] Loss: 0.0647\n", "35th epoch starting.\n", "Epoch [36] Loss: 0.1111\n", "36th epoch starting.\n", "Epoch [37] Loss: 0.1146\n", "37th epoch starting.\n", "Epoch [38] Loss: 0.0665\n", "38th epoch starting.\n", "Epoch [39] Loss: 0.1160\n", "39th epoch starting.\n", "Epoch [40] Loss: 0.1581\n", "40th epoch starting.\n", "Epoch [41] Loss: 0.1007\n", "41th epoch starting.\n", "Epoch [42] Loss: 0.0497\n", "42th epoch starting.\n", "Epoch [43] Loss: 0.0371\n", "43th epoch starting.\n", "Epoch [44] Loss: 0.0768\n", "44th epoch starting.\n", "Epoch [45] Loss: 0.0543\n", "45th epoch starting.\n", "Epoch [46] Loss: 0.0976\n", "46th epoch starting.\n", "Epoch [47] Loss: 0.0502\n", "47th epoch starting.\n", "Epoch [48] Loss: 0.1551\n", "48th epoch starting.\n", "Epoch [49] Loss: 0.0827\n", "49th epoch starting.\n", "Epoch [50] Loss: 0.1055\n", "50th epoch starting.\n", "Epoch [51] Loss: 0.1577\n", "51th epoch starting.\n", "Epoch [52] Loss: 0.0821\n", "52th epoch starting.\n", "Epoch [53] Loss: 0.0847\n", "53th epoch starting.\n", "Epoch [54] Loss: 0.1028\n", "54th epoch starting.\n", "Epoch [55] Loss: 0.0845\n", "55th epoch starting.\n", "Epoch [56] Loss: 0.0814\n", "56th epoch starting.\n", "Epoch [57] Loss: 0.0833\n", "57th epoch starting.\n", "Epoch [58] Loss: 0.1108\n", "58th epoch starting.\n", "Epoch [59] Loss: 0.1267\n", "59th epoch starting.\n", "Epoch [60] Loss: 0.1371\n", "60th epoch starting.\n", "Epoch [61] Loss: 0.0947\n", "61th epoch starting.\n", "Epoch [62] Loss: 0.1374\n", "62th epoch starting.\n", "Epoch [63] Loss: 0.0665\n", "63th epoch starting.\n", "Epoch [64] Loss: 0.1066\n", "64th epoch starting.\n", "Epoch [65] Loss: 0.0788\n", "65th epoch starting.\n", "Epoch [66] Loss: 0.0944\n", "66th epoch starting.\n", "Epoch [67] Loss: 0.0541\n", "67th epoch starting.\n", "Epoch [68] Loss: 0.1057\n", "68th epoch starting.\n", "Epoch [69] Loss: 0.1278\n", "69th epoch starting.\n", "Epoch [70] Loss: 0.0634\n", "70th epoch starting.\n", "Epoch [71] Loss: 0.0315\n", "71th epoch starting.\n", "Epoch [72] Loss: 0.1361\n", "72th epoch starting.\n", "Epoch [73] Loss: 0.0487\n", "73th epoch starting.\n", "Epoch [74] Loss: 0.1465\n", "74th epoch starting.\n", "Epoch [75] Loss: 0.0442\n", "75th epoch starting.\n", "Epoch [76] Loss: 0.0595\n", "76th epoch starting.\n", "Epoch [77] Loss: 0.1443\n", "77th epoch starting.\n", "Epoch [78] Loss: 0.0728\n", "78th epoch starting.\n", "Epoch [79] Loss: 0.0846\n", "79th epoch starting.\n", "Epoch [80] Loss: 0.0924\n", "80th epoch starting.\n", "Epoch [81] Loss: 0.0924\n", "81th epoch starting.\n", "Epoch [82] Loss: 0.0758\n", "82th epoch starting.\n", "Epoch [83] Loss: 0.0511\n", "83th epoch starting.\n", "Epoch [84] Loss: 0.1416\n", "84th epoch starting.\n", "Epoch [85] Loss: 0.0999\n", "85th epoch starting.\n", "Epoch [86] Loss: 0.1032\n", "86th epoch starting.\n", "Epoch [87] Loss: 0.0479\n", "87th epoch starting.\n", "Epoch [88] Loss: 0.0925\n", "88th epoch starting.\n", "Epoch [89] Loss: 0.0760\n", "89th epoch starting.\n", "Epoch [90] Loss: 0.0498\n", "90th epoch starting.\n", "Epoch [91] Loss: 0.0964\n", "91th epoch starting.\n", "Epoch [92] Loss: 
0.0629\n", "92th epoch starting.\n", "Epoch [93] Loss: 0.0651\n", "93th epoch starting.\n", "Epoch [94] Loss: 0.0354\n", "94th epoch starting.\n", "Epoch [95] Loss: 0.1460\n", "95th epoch starting.\n", "Epoch [96] Loss: 0.1071\n", "96th epoch starting.\n", "Epoch [97] Loss: 0.0803\n", "97th epoch starting.\n", "Epoch [98] Loss: 0.0751\n", "98th epoch starting.\n", "Epoch [99] Loss: 0.0906\n", "99th epoch starting.\n", "Epoch [100] Loss: 0.0803\n", "Time elapsed in training is: 8126.175984621048\n", "[Test set] Average loss: 0.0072, Accuracy: 8222/10000 (82.22%)\n", "\n" ] } ], "source": [ "import torch\n",
"import torch.nn as nn\n",
"from torch.optim import Optimizer\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets\n",
"from torchvision.transforms import transforms\n",
"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"'''\n",
"Step 1:\n",
"'''\n",
"transform = transforms.Compose([\n",
"    transforms.Pad(4),\n",
"    transforms.RandomHorizontalFlip(),\n",
"    transforms.RandomCrop(32),\n",
"    transforms.ToTensor()])\n",
"\n",
"train_dataset = datasets.CIFAR10(root='./cifar_10data/',\n",
"                                 train=True,\n",
"                                 transform=transform,\n",
"                                 download=True)\n",
"\n",
"test_dataset = datasets.CIFAR10(root='./cifar_10data/',\n",
"                                train=False,\n",
"                                transform=transforms.ToTensor())\n",
"\n",
"'''\n",
"Step 2:\n",
"'''\n",
"class GoogLeNet(nn.Module):\n",
"\n",
"    def __init__(self):\n",
"        super(GoogLeNet, self).__init__()\n",
"\n",
"        self.conv1 = BasicConv2d(3, 64, kernel_size=7, padding=3)\n",
"        self.conv2 = BasicConv2d(64, 64, kernel_size=1)\n",
"        # note: 5x5 without padding shrinks the 32x32 CIFAR input to 28x28\n",
"        # before the Inception stages\n",
"        self.conv3 = BasicConv2d(64, 192, kernel_size=5)\n",
"\n",
"        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)\n",
"        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)\n",
"        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)\n",
"\n",
"        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)\n",
"        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)\n",
"        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)\n",
"        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)\n",
"        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)\n",
"        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n",
"\n",
"        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)\n",
"        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)\n",
"\n",
"        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
"        self.dropout = nn.Dropout(0.4)\n",
"        self.fc = nn.Linear(1024, 10)\n",
"\n",
"        for m in self.modules():\n",
"            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):\n",
"                torch.nn.init.xavier_uniform_(m.weight)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.conv1(x)\n",
"        x = self.conv2(x)\n",
"        x = self.conv3(x)\n",
"\n",
"        x = self.inception3a(x)\n",
"        x = self.inception3b(x)\n",
"        x = self.maxpool3(x)\n",
"\n",
"        x = self.inception4a(x)\n",
"        x = self.inception4b(x)\n",
"        x = self.inception4c(x)\n",
"        x = self.inception4d(x)\n",
"        x = self.inception4e(x)\n",
"        x = self.maxpool4(x)\n",
"\n",
"        x = self.inception5a(x)\n",
"        x = self.inception5b(x)\n",
"        x = self.avgpool(x)\n",
"\n",
"        x = torch.flatten(x, 1)\n",
"        x = self.dropout(x)\n",
"        x = self.fc(x)\n",
"        return x\n",
"\n",
"class Inception(nn.Module):\n",
"\n",
"    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):\n",
"        super(Inception, self).__init__()\n",
"\n",
"        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)\n",
"\n",
" self.branch2 = nn.Sequential(\n", " BasicConv2d(in_channels, ch3x3red, kernel_size=1),\n", " BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)\n", " )\n", "\n", " self.branch3 = nn.Sequential(\n", " BasicConv2d(in_channels, ch5x5red, kernel_size=1),\n", " BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)\n", " )\n", "\n", " self.branch4 = nn.Sequential(\n", " nn.MaxPool2d(kernel_size=3, stride=1, padding=1),\n", " BasicConv2d(in_channels, pool_proj, kernel_size=1)\n", " )\n", "\n", " def forward(self, x):\n", " branch1 = self.branch1(x)\n", " branch2 = self.branch2(x)\n", " branch3 = self.branch3(x)\n", " branch4 = self.branch4(x)\n", "\n", " return torch.cat([branch1, branch2, branch3, branch4], 1)\n", "\n", "\n", "class BasicConv2d(nn.Module):\n", " def __init__(self, in_channels, out_channels, **kwargs):\n", " super(BasicConv2d, self).__init__()\n", " self.conv = nn.Sequential(\n", " nn.Conv2d(in_channels, out_channels, **kwargs),\n", " nn.BatchNorm2d(out_channels), #Batch norm here\n", " nn.ReLU()\n", " )\n", " def forward(self, x):\n", " return self.conv(x)\n", " \n", "\n", "\n", "'''\n", "Step 3\n", "'''\n", "model = GoogLeNet().to(device)\n", "loss_function = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=5e-4)\n", "\n", "'''\n", "Step 4\n", "'''\n", "model.train()\n", "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)\n", "\n", "import time\n", "start = time.time()\n", "for epoch in range(100) :\n", " print(\"{}th epoch starting.\".format(epoch))\n", " for i, (images, labels) in enumerate(train_loader) :\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " optimizer.zero_grad()\n", " train_loss = loss_function(model(images), labels)\n", " train_loss.backward()\n", "\n", " optimizer.step()\n", "\n", " print (\"Epoch [{}] Loss: {:.4f}\".format(epoch+1, train_loss.item()))\n", "\n", "end = time.time()\n", "print(\"Time ellapsed in training is: {}\".format(end - start))\n", "\n", "\n", "'''\n", "Step 5\n", "'''\n", "model.eval()\n", "test_loss, correct, total = 0, 0, 0\n", "\n", "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)\n", "with torch.no_grad(): #using context manager\n", " for images, labels in test_loader :\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " output = model(images)\n", " test_loss += loss_function(output, labels).item()\n", "\n", " pred = output.max(1, keepdim=True)[1]\n", " correct += pred.eq(labels.view_as(pred)).sum().item()\n", "\n", " total += labels.size(0)\n", "\n", "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n", " test_loss /total, correct, total,\n", " 100. 
{ "cell_type": "markdown", "id": "d5e4b574", "metadata": {}, "source": [ "# ResNet" ] }, { "cell_type": "code", "execution_count": null, "id": "77da4776", "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"class ResBlock(nn.Module):\n",
"    def __init__(self, in_channels, out_channels, stride=1):\n",
"        super(ResBlock, self).__init__()\n",
"        self.conv1 = nn.Sequential(\n",
"            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),\n",
"            nn.BatchNorm2d(out_channels),\n",
"            nn.ReLU()\n",
"        )\n",
"        self.conv2 = nn.Sequential(\n",
"            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),\n",
"            nn.BatchNorm2d(out_channels)\n",
"        )\n",
"\n",
"        # 1x1 projection shortcut whenever the residual's shape differs\n",
"        self.downsample = None\n",
"        if stride != 1 or in_channels != out_channels:\n",
"            self.downsample = nn.Sequential(\n",
"                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n",
"                nn.BatchNorm2d(out_channels)\n",
"            )\n",
"\n",
"    def forward(self, x):\n",
"        out = self.conv1(x)\n",
"        out = self.conv2(out)\n",
"\n",
"        residual = x\n",
"        if self.downsample is not None:\n",
"            residual = self.downsample(residual)\n",
"\n",
"        return F.relu(out + residual)\n",
"\n",
"\n",
"class ResNet(nn.Module):\n",
"    def __init__(self, num_blocks, num_classes=10):\n",
"        super(ResNet, self).__init__()\n",
"\n",
"        self.layer1 = nn.Sequential(\n",
"            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),\n",
"            nn.BatchNorm2d(16)\n",
"        )\n",
"\n",
"        self.layer2 = [ResBlock(16, 16)]\n",
"        for i in range(num_blocks[0] - 1):\n",
"            self.layer2.append(ResBlock(16, 16))\n",
"        self.layer2 = nn.ModuleList(self.layer2)\n",
"\n",
"        self.layer3 = [ResBlock(16, 32, 2)]\n",
"        for i in range(num_blocks[1] - 1):\n",
"            self.layer3.append(ResBlock(32, 32))\n",
"        self.layer3 = nn.ModuleList(self.layer3)\n",
"\n",
"        self.layer4 = [ResBlock(32, 64, 2)]\n",
"        for i in range(num_blocks[2] - 1):\n",
"            self.layer4.append(ResBlock(64, 64))\n",
"        self.layer4 = nn.ModuleList(self.layer4)\n",
"\n",
"        self.avgpool = nn.AvgPool2d(8)\n",
"        self.fc = nn.Linear(64, num_classes)\n",
"\n",
"        self.relu = nn.ReLU(inplace=True)\n",
"\n",
"    def forward(self, x):\n",
"        out = self.layer1(x)\n",
"        out = self.relu(out)\n",
"\n",
"        for i in range(len(self.layer2)):\n",
"            out = self.layer2[i](out)\n",
"\n",
"        for i in range(len(self.layer3)):\n",
"            out = self.layer3[i](out)\n",
"\n",
"        for i in range(len(self.layer4)):\n",
"            out = self.layer4[i](out)\n",
"\n",
"        out = self.avgpool(out)\n",
"        out = nn.Flatten()(out)\n",
"        out = self.fc(out)\n",
"\n",
"        return out\n",
"\n",
"\n",
"def resnet20():\n",
"    return ResNet([3, 3, 3])\n",
"\n",
"\n",
"def resnet32():\n",
"    return ResNet([5, 5, 5])\n",
"\n",
"\n",
"def resnet44():\n",
"    return ResNet([7, 7, 7])\n",
"\n",
"\n",
"def resnet56():\n",
"    return ResNet([9, 9, 9])\n",
"\n",
"\n",
"\n",
"import torch\n",
"from torchvision.datasets import CIFAR10, CIFAR100\n",
"import torchvision.transforms as T\n",
"from torch.utils.data import DataLoader\n",
"import time\n",
"# from resnet import resnet56\n",
"\n",
"\n",
"train_transform = T.Compose([\n",
"    T.RandomCrop(size=32, padding=4),\n",
"    T.RandomHorizontalFlip(),\n",
"    T.ToTensor(),\n",
"    T.Normalize(\n",
"        mean=[0.485, 0.456, 0.406],\n",
"        std=[0.229, 0.224, 0.225]\n",
"    )\n",
"])\n",
"\n",
"val_transform = T.Compose([\n",
"    T.ToTensor(),\n",
"    T.Normalize(\n",
"        mean=[0.485, 0.456, 0.406],\n",
"        std=[0.229, 0.224, 0.225]\n",
"    )\n",
"])\n",
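"\n",
"# Illustrative shape check (an added sketch, not in the original script):\n",
"# the three stages take 32x32 inputs to 8x8 maps with 64 channels, so the\n",
"# 8x8 average pool plus the final linear layer yields one logit per class.\n",
"assert resnet20()(torch.zeros(2, 3, 32, 32)).shape == (2, 10)\n",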
"\n", "try:\n", " train_ds = CIFAR10(root='./', train=True, transform=train_transform,\n", " download=False)\n", "except:\n", " train_ds = CIFAR10(root='./', train=True, transform=train_transform,\n", " download=True)\n", "\n", "train_dl = DataLoader(train_ds, batch_size=128, shuffle=True)\n", "\n", "try:\n", " val_ds = CIFAR10(root='./', train=False, transform=val_transform,\n", " download=False)\n", "except:\n", " val_ds = CIFAR10(root='./', train=False, transform=val_transform,\n", " download=True)\n", "\n", "val_dl = DataLoader(val_ds, batch_size=128)\n", "\n", "print(f\"Total {len(train_ds)} Training Data\")\n", "print(f\"Total {len(val_ds)} Validation Data\")\n", "\n", "\n", "\n", "model = resnet56()\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "model = model.to(device)\n", "\n", "optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)\n", "\n", "lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)\n", "\n", "criterion = torch.nn.CrossEntropyLoss()\n", "\n", "epochs = 150\n", "\n", "best_acc = 0.\n", "total_time = 0\n", "\n", "for epoch in range(epochs):\n", " tick = time.time()\n", " model.train()\n", " epoch_loss = 0.\n", " if epoch == 99 or epoch == 124:\n", " lr_scheduler.step()\n", " for data in train_dl:\n", " optimizer.zero_grad()\n", " img, label = data[0].to(device), data[1].to(device)\n", "\n", " pred = model(img)\n", "\n", " loss = criterion(pred, label)\n", " loss.backward()\n", " optimizer.step()\n", "\n", " epoch_loss += loss.item()\n", "\n", " print(f\"\\nEpoch {epoch + 1:4d} Train Loss: {epoch_loss:.6f}\")\n", "\n", " model.eval()\n", " correct = 0.\n", " with torch.no_grad():\n", " for data in val_dl:\n", " img, label = data[0].to(device), data[1].to(device)\n", "\n", " pred = model(img)\n", " pred = torch.argmax(pred.data, 1)\n", " correct += (pred == label).sum().item()\n", "\n", " print(f\"Epoch {epoch + 1:4d} Validation Accuracy: {100 * correct / len(val_ds)}%\")\n", "\n", " if best_acc < correct:\n", " print(f\"New Best Accuracy\")\n", " best_acc = correct\n", " torch.save({\n", " 'model_state_dict': model.state_dict()\n", " }, 'model.pt')\n", " tock = time.time()\n", " total_time += tock - tick\n", " print(f\"Total Time for Epoch {epoch + 1:4d}: {tock - tick:.6f}\")\n", "\n", "print(f\"Total Time: {total_time:.6f}\")\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "id": "08ecf6b4", "metadata": {}, "source": [ "# ResNext" ] }, { "cell_type": "code", "execution_count": null, "id": "9510fd7c", "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "\n", "# Building Block for ResNeXt (type (a))\n", "class AggregatedBlock(nn.Module):\n", " cardinality = 8\n", "\n", " def __init__(self, in_channels, out_channels, stride=1):\n", " super(AggregatedBlock, self).__init__()\n", "\n", " assert out_channels % self.cardinality == 0\n", " mid_channels = out_channels // self.cardinality\n", "\n", " self.layer = []\n", " for _ in range(self.cardinality):\n", " self.layer.append(nn.Sequential(\n", " nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1, bias=False),\n", " nn.BatchNorm2d(mid_channels),\n", " nn.ReLU(),\n", " nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, bias=False),\n", " nn.BatchNorm2d(mid_channels),\n", " nn.ReLU(),\n", " nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, bias=False),\n", " nn.BatchNorm2d(out_channels)\n", 
" ))\n", " self.layer = nn.ModuleList(self.layer)\n", "\n", " if in_channels != out_channels or stride != 1:\n", " self.downsample = nn.Sequential(\n", " nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n", " nn.BatchNorm2d(out_channels)\n", " )\n", " else:\n", " self.downsample = None\n", "\n", " def forward(self, x):\n", " out = sum([b(x) for b in self.layer])\n", "\n", " residual = x\n", " if self.downsample is not None:\n", " residual = self.downsample(residual)\n", "\n", " return F.relu(out + residual)\n", "\n", "\n", "# Building Block for ResNeXt (type (b))\n", "class InceptionBlock(nn.Module):\n", " cardinality = 8\n", "\n", " def __init__(self, in_channels, out_channels, stride=1):\n", " super(InceptionBlock, self).__init__()\n", "\n", " assert out_channels % self.cardinality == 0\n", " mid_channels = out_channels // self.cardinality\n", "\n", " self.layer = []\n", " for i in range(self.cardinality):\n", " self.layer.append(nn.Sequential(\n", " nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1, bias=False),\n", " nn.BatchNorm2d(mid_channels),\n", " nn.ReLU(),\n", " nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, bias=False),\n", " nn.BatchNorm2d(mid_channels),\n", " nn.ReLU()\n", " ))\n", " self.layer = nn.ModuleList(self.layer)\n", "\n", " self.tail = nn.Sequential(\n", " nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),\n", " nn.BatchNorm2d(out_channels)\n", " )\n", "\n", " if in_channels != out_channels or stride != 1:\n", " self.downsample = nn.Sequential(\n", " nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n", " nn.BatchNorm2d(out_channels)\n", " )\n", " else:\n", " self.downsample = None\n", "\n", " def forward(self, x):\n", " out = [b(x) for b in self.layer]\n", " out = torch.cat(out, dim=1)\n", " out = self.tail(out)\n", "\n", " residual = x\n", " if self.downsample is not None:\n", " residual = self.downsample(residual)\n", "\n", " return F.relu(out + residual)\n", "\n", "\n", "# Building Block for ResNeXt (type (c))\n", "class GroupConvBlock(nn.Module):\n", " cardinality = 8\n", "\n", " def __init__(self, in_channels, out_channels, stride=1):\n", " super(GroupConvBlock, self).__init__()\n", "\n", " assert out_channels % self.cardinality == 0\n", "\n", " self.layer = nn.Sequential(\n", " nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),\n", " nn.BatchNorm2d(out_channels),\n", " nn.ReLU(),\n", " nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, groups=self.cardinality,\n", " padding=1, bias=False),\n", " nn.BatchNorm2d(out_channels),\n", " nn.ReLU(),\n", " nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),\n", " nn.BatchNorm2d(out_channels)\n", " )\n", "\n", " if in_channels != out_channels or stride != 1:\n", " self.downsample = nn.Sequential(\n", " nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n", " nn.BatchNorm2d(out_channels)\n", " )\n", " else:\n", " self.downsample = None\n", "\n", " def forward(self, x):\n", " out = self.layer(x)\n", "\n", " residual = x\n", " if self.downsample is not None:\n", " residual = self.downsample(residual)\n", "\n", " return F.relu(out + residual)\n", "\n", "\n", "# This is an illustration of how groupconv can be implemented\n", "class GroupConvLayer(nn.Module):\n", " def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False):\n", " super(GroupConvLayer, self).__init__()\n", 
"\n", " assert in_channels % groups == 0 and out_channels % groups == 0\n", "\n", " self.groups = groups\n", " self.layer = []\n", " self.width = in_channels // groups\n", " for i in range(groups):\n", " self.layer.append(nn.Conv2d(\n", " in_channels=in_channels // groups,\n", " out_channels=out_channels // groups,\n", " kernel_size=kernel_size,\n", " stride=stride,\n", " bias=bias\n", " ))\n", " self.layer = nn.ModuleList(self.layer)\n", "\n", " def forward(self, x):\n", " w = self.width\n", " out = [layer(x[:, i * w: (i + 1) * w]) for i, layer in enumerate(self.layers)]\n", "\n", " return torch.cat(out, dim=1)\n", "\n", "\n", "class ResNeXt(nn.Module):\n", " cardinality = 32\n", "\n", " def __init__(self, topology_type, num_blocks, num_classes=100):\n", " assert topology_type in ['a', 'b', 'c']\n", "\n", " super(ResNeXt, self).__init__()\n", " self.layer1 = nn.Sequential(\n", " nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),\n", " nn.BatchNorm2d(16)\n", " )\n", "\n", " self.topology_type = topology_type\n", "\n", " if topology_type == 'a':\n", " block_type = AggregatedBlock\n", " elif topology_type == 'b':\n", " block_type = InceptionBlock\n", " elif topology_type == 'c':\n", " block_type = GroupConvBlock\n", "\n", " self.layer2 = [block_type(16, 16) for _ in range(num_blocks[0])]\n", " self.layer2 = nn.ModuleList(self.layer2)\n", "\n", " self.layer3 = [block_type(16, 32, 2)]\n", " self.layer3 += [block_type(32, 32) for _ in range(num_blocks[1] - 1)]\n", " self.layer3 = nn.ModuleList(self.layer3)\n", "\n", " self.layer4 = [block_type(32, 64, 2)]\n", " self.layer4 += [block_type(64, 64) for _ in range(num_blocks[2] - 1)]\n", " self.layer4 = nn.ModuleList(self.layer4)\n", "\n", " self.avgpool = nn.AvgPool2d(8)\n", " self.fc = nn.Linear(64, num_classes)\n", "\n", " self.relu = nn.ReLU(inplace=True)\n", "\n", " def forward(self, x):\n", " out = self.layer1(x)\n", " out = self.relu(out)\n", "\n", " for layer in self.layer2:\n", " out = layer(out)\n", "\n", " for layer in self.layer3:\n", " out = layer(out)\n", "\n", " for layer in self.layer4:\n", " out = layer(out)\n", "\n", " out = self.avgpool(out)\n", " out = torch.squeeze(out)\n", " out = self.fc(out)\n", "\n", " return out" ] }, { "cell_type": "markdown", "id": "cb98e838", "metadata": {}, "source": [ "# SENet" ] }, { "cell_type": "code", "execution_count": null, "id": "fc63a0cd", "metadata": {}, "outputs": [], "source": [ "\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "\n", "class SEModule(nn.Module):\n", " def __init__(self, channels, reduction=16):\n", " super(SEModule, self).__init__()\n", " self.globalAvgPool = nn.AdaptiveAvgPool2d(1)\n", " self.fc = nn.Sequential(\n", " nn.Linear(channels, channels // reduction),\n", " nn.ReLU(),\n", " nn.Linear(channels // reduction, channels),\n", " nn.Sigmoid()\n", " )\n", "\n", " def forward(self, x):\n", " out = torch.squeeze(self.globalAvgPool(x))\n", " out = self.fc(out).view(x.size()[0], x.size()[1], 1, 1)\n", "\n", " # both methods works\n", " # return x * out\n", " # return x * out.expand_as(x)\n", "\n", "\n", "class SEBlock(nn.Module):\n", " def __init__(self, in_channels, out_channels, stride=1, reduction=16):\n", " super(SEBlock, self).__init__()\n", " self.conv1 = nn.Sequential(\n", " nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),\n", " nn.BatchNorm2d(out_channels),\n", " nn.ReLU()\n", " )\n", "\n", " self.conv2 = nn.Sequential(\n", " nn.Conv2d(out_channels, out_channels, kernel_size=3, 
 ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.11" } }, "nbformat": 4, "nbformat_minor": 5 }